use std::{any::Any, collections::HashMap};

use impl_downcast::impl_downcast;

use crate::{Context, Material};

/// A rendering stage that draws a [`Scene`].
///
/// NOTE(review): `init` is presumably invoked once before any `draw` call —
/// confirm against the code that drives passes.
pub trait Pass {
    /// Initialization hook that receives the rendering context.
    fn init(&mut self, ctx: &Context);

    /// Draw hook.
    ///
    /// Receives the rendering context, the `screen_view` texture view, the
    /// command `encoder` (mutably, so the pass can record into it), and the
    /// `scene` to draw.
    fn draw(
        &mut self,
        ctx: &Context,
        screen_view: &wgpu::TextureView,
        encoder: &mut wgpu::CommandEncoder,
        scene: &Scene,
    );
}

/// A renderable unit stored by name inside a [`Scene`].
///
/// The `Any` supertrait, together with the `impl_downcast!` invocation below,
/// allows a `Box<dyn Brick>` to be downcast back to its concrete type
/// (this is what `Scene::get_brick` / `Scene::get_brick_mut` rely on).
pub trait Brick: Any {
    /// Preparation hook; `Scene::prepare` calls it on every brick.
    fn prepare(&mut self, ctx: &super::Context);

    /// Renders this brick into `pass`.
    ///
    /// `Scene::render` calls this after it has bound all of the scene's
    /// global bind groups.
    fn render<'pass>(&'pass self, pass: &mut wgpu::RenderPass<'pass>);
}

impl_downcast!(Brick);

/// A set of global materials plus named bricks that are prepared and
/// rendered together.
pub struct Scene {
    /// Value handed out by [`Scene::clear_color`]; set once in [`Scene::new`].
    clear_color: wgpu::Color,
    /// Global materials, keyed by the bind-group index they are bound at
    /// during [`Scene::render`].
    globals: HashMap<u32, Material>,
    /// Type-erased bricks keyed by name. Being a `HashMap`, iteration order
    /// (and therefore render order) is unspecified.
    bricks: HashMap<String, Box<dyn Brick>>,
}

impl Scene {
    /// Creates an empty scene (no globals, no bricks) with the given clear color.
    pub fn new(clear_color: wgpu::Color) -> Self {
        Self {
            clear_color,
            globals: HashMap::new(),
            bricks: HashMap::new(),
        }
    }

    /// Registers `mat` as the global material at `bind_point`.
    ///
    /// Any material previously registered at that bind point is replaced and dropped.
    pub fn put_global(&mut self, bind_point: u32, mat: Material) {
        self.globals.insert(bind_point, mat);
    }

    /// Returns the global material registered at `bind_point`, if any.
    pub fn global(&self, bind_point: u32) -> Option<&Material> {
        self.globals.get(&bind_point)
    }

    /// Mutable counterpart of [`Scene::global`].
    pub fn global_mut(&mut self, bind_point: u32) -> Option<&mut Material> {
        self.globals.get_mut(&bind_point)
    }

    /// Returns the clear color supplied to [`Scene::new`].
    pub fn clear_color(&self) -> wgpu::Color {
        self.clear_color
    }

    /// Looks up the brick registered under `name` and downcasts it to `B`.
    ///
    /// Returns `None` if no brick has that name, or if the brick stored under
    /// that name is not a `B`.
    // Only shared access is needed. Taking `&self` (previously `&mut self`) lets
    // callers query bricks while other shared borrows of the scene are alive.
    // This stays backward-compatible with existing `&mut Scene` callers.
    pub fn get_brick<B: Brick>(&self, name: &str) -> Option<&B> {
        self.bricks.get(name).and_then(|b| b.downcast())
    }

    /// Mutable counterpart of [`Scene::get_brick`].
    ///
    /// Returns `None` if no brick has that name or if it is not a `B`.
    pub fn get_brick_mut<B: Brick>(&mut self, name: &str) -> Option<&mut B> {
        self.bricks.get_mut(name).and_then(|b| b.downcast_mut())
    }

    /// Boxes `brick` and registers it under `name`.
    ///
    /// Any brick previously registered under the same name is replaced and dropped.
    pub fn add_brick<B: Brick>(&mut self, name: impl Into<String>, brick: B) {
        self.bricks.insert(name.into(), Box::new(brick));
    }

    /// Calls `prepare(ctx)` on every global material first, then on every brick.
    pub fn prepare(&mut self, ctx: &super::Context) {
        for global in self.globals.values_mut() {
            global.prepare(ctx)
        }

        for brick in self.bricks.values_mut() {
            brick.prepare(ctx)
        }
    }

    /// Records the scene into `pass`.
    ///
    /// Each global material's bind group is set at its bind point (with no
    /// dynamic offsets). After that, every brick renders itself into the pass.
    ///
    /// Note: bricks are stored in a `HashMap`, so the order in which they are
    /// rendered is unspecified. If draw order matters (e.g. alpha blending),
    /// an ordered container would be needed.
    pub fn render<'pass>(&'pass self, pass: &mut wgpu::RenderPass<'pass>) {
        for (bind_point, global) in self.globals.iter() {
            pass.set_bind_group(*bind_point, global.group(), &[])
        }

        for brick in self.bricks.values() {
            brick.render(pass)
        }
    }

    /// Returns the bind-group layout of the global material at `bind_point`.
    ///
    /// Returns `None` if nothing is registered there.
    pub fn global_layout(&self, bind_point: u32) -> Option<&wgpu::BindGroupLayout> {
        self.global(bind_point).map(Material::layout)
    }
}
